{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "WARNING: Not monitoring node memory since `psutil` is not installed. Install this with `pip install psutil` (or ray[debug]) to enable debugging of memory-related crashes.\n"
     ]
    }
   ],
   "source": [
    "from ast import literal_eval\n",
    "import ray\n",
    "import gym\n",
    "import retro\n",
    "import os\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "from markov import sampleMarkov, createMarkov, randMarkov\n",
    "from support import getInitial, verifyTrajectory, install_games_from_rom_dir, frameToCell, action_set, trajectoryToGif\n",
    "import time"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import imageio\n",
    "# Fetch imageio's FreeImage plugin binary. Presumably required by the GIF export\n",
    "# done in trajectoryToGif -- TODO confirm against support.py.\n",
    "imageio.plugins.freeimage.download()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "WARNING: Not updating worker name since `setproctitle` is not installed. Install this with `pip install setproctitle` (or ray[debug]) to enable monitoring of worker processes.\n",
      "Process STDOUT and STDERR is being redirected to /tmp/ray/session_2019-01-23_00-10-20_2192/logs.\n",
      "Waiting for redis server at 127.0.0.1:46177 to respond...\n",
      "Waiting for redis server at 127.0.0.1:10050 to respond...\n",
      "Starting Redis shard with 10.0 GB max memory.\n",
      "WARNING: The object store is using /tmp instead of /dev/shm because /dev/shm has only 2147483648 bytes available. This may slow down performance! You may be able to free up space by deleting files in /dev/shm or terminating any running plasma_store_server processes. If you are inside a Docker container, you may need to pass an argument with the flag '--shm-size' to 'docker run'.\n",
      "Starting the Plasma object store with 12.785916313 GB memory using /tmp.\n",
      "\n",
      "======================================================================\n",
      "View the web UI at http://localhost:8888/notebooks/ray_ui.ipynb?token=c43331ca232bb127a9007cdf2aa763cc09a8e9dc596ccedd\n",
      "======================================================================\n",
      "\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "{'node_ip_address': None,\n",
       " 'redis_address': '172.17.0.6:46177',\n",
       " 'object_store_address': '/tmp/ray/session_2019-01-23_00-10-20_2192/sockets/plasma_store',\n",
       " 'webui_url': 'http://localhost:8888/notebooks/ray_ui.ipynb?token=c43331ca232bb127a9007cdf2aa763cc09a8e9dc596ccedd',\n",
       " 'raylet_socket_name': '/tmp/ray/session_2019-01-23_00-10-20_2192/sockets/raylet'}"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Start a local Ray runtime with default settings; the recorded output shows it\n",
    "# launching Redis and the Plasma object store.\n",
    "ray.init()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def winCondition(cell, info):\n",
    "    \"\"\"Return True when `info` reports a non-zero 'level_end_bonus'.\n",
    "\n",
    "    Args:\n",
    "        cell: Not read by this function.\n",
    "        info: Mapping that must contain the key 'level_end_bonus'.\n",
    "\n",
    "    Returns:\n",
    "        bool: Whether info['level_end_bonus'] differs from 0.\n",
    "    \"\"\"\n",
    "    bonus = info['level_end_bonus']\n",
    "    return bonus != 0\n",
    "\n",
    "def stopCondition(cell, info, step, max_steps=500, start_lives=3):\n",
    "    \"\"\"Decide whether the current exploration episode should end.\n",
    "\n",
    "    The episode ends when the step budget is exceeded, when winCondition\n",
    "    holds, or when the life counter no longer equals `start_lives`.\n",
    "\n",
    "    Args:\n",
    "        cell: Forwarded unchanged to winCondition.\n",
    "        info: Mapping with the keys 'level_end_bonus' and 'lives'.\n",
    "        step: Number of steps taken so far in this episode.\n",
    "        max_steps: Step budget; the episode stops once `step` is strictly\n",
    "            greater than this value. Defaults to the previous hard-coded 500.\n",
    "        start_lives: Life count the episode is expected to keep. Defaults to\n",
    "            the previous hard-coded 3.\n",
    "\n",
    "    Returns:\n",
    "        bool: True if the episode should stop.\n",
    "    \"\"\"\n",
    "    # Check the step budget first so `info` is not touched once it is exhausted\n",
    "    # (same short-circuit order as before).\n",
    "    if step > max_steps:\n",
    "        return True\n",
    "    return winCondition(cell, info) or info['lives'] != start_lives\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "@ray.remote\n",
    "class MasterActor(object):\n",
    "    \"\"\"Ray actor that owns the shared archive of explored cells.\n",
    "\n",
    "    For every known cell it keeps the shortest trajectory pushed so far, the\n",
    "    state that came with it, and an unnormalized sampling weight. Fitness is\n",
    "    the trajectory length, so a smaller value is better. Results that satisfy\n",
    "    winCondition are not archived as cells; only the shortest winning\n",
    "    trajectory is remembered.\n",
    "    \"\"\"\n",
    "\n",
    "    # Action count handed to createMarkov; equals the 12 used by the worker's\n",
    "    # np.random.randint(12) calls.\n",
    "    NUM_ACTIONS = 12\n",
    "    # Sampling weight assigned to a cell the first time it is reached.\n",
    "    NEW_CELL_WEIGHT = 10\n",
    "    # Weight added to a known cell whenever a shorter trajectory reaches it.\n",
    "    IMPROVEMENT_BONUS = 1\n",
    "    # Additive padding applied to every weight by renormalizeCellProbs.\n",
    "    PROB_PADDING = .1\n",
    "\n",
    "    def __init__(self,\n",
    "                 initialPolicy,\n",
    "                 initialCell,\n",
    "                 initialFitness,\n",
    "                 initialTrajectory,\n",
    "                 initialState):\n",
    "        \"\"\"Seed the archive with a single starting cell.\n",
    "\n",
    "        Args:\n",
    "            initialPolicy: Dict with at least the keys 'type' and 'weights'.\n",
    "            initialCell: Hashable identifier of the starting cell.\n",
    "            initialFitness: Fitness recorded for the starting cell.\n",
    "            initialTrajectory: Action list that leads to the starting cell.\n",
    "            initialState: State to restore in order to resume from that cell.\n",
    "        \"\"\"\n",
    "        self.best_trajectory = None\n",
    "        self.best_fitness = None\n",
    "        self.policy = initialPolicy\n",
    "        # `cells` preserves discovery order; the dicts below share its keys.\n",
    "        self.cells = [initialCell]\n",
    "        self.fitnesses = {initialCell: initialFitness}\n",
    "        self.cell_prob = {initialCell: 1}\n",
    "        self.trajectories = {initialCell: initialTrajectory}\n",
    "        self.states = {initialCell: initialState}\n",
    "\n",
    "    def pushResult(self, cell, trajectory, state, info, step):\n",
    "        \"\"\"Record a worker result if it improves on what is already known.\n",
    "\n",
    "        Args:\n",
    "            cell: Identifier of the cell that was reached.\n",
    "            trajectory: Action list that reached it; its length is the fitness.\n",
    "            state: State captured on arrival.\n",
    "            info: Mapping passed to winCondition.\n",
    "            step: Accepted for interface compatibility; not used here.\n",
    "        \"\"\"\n",
    "        fitness = len(trajectory)\n",
    "\n",
    "        # Dict membership is O(1); `self.fitnesses` always has the same keys\n",
    "        # as `self.cells`, so this equals the former list scan.\n",
    "        if cell in self.fitnesses:\n",
    "            if fitness < self.fitnesses[cell]:\n",
    "                # Shorter route to an already known cell.\n",
    "                self.fitnesses[cell] = fitness\n",
    "                self.trajectories[cell] = trajectory\n",
    "                self.states[cell] = state\n",
    "                self.cell_prob[cell] += self.IMPROVEMENT_BONUS\n",
    "            return\n",
    "\n",
    "        if winCondition(cell, info):\n",
    "            # First win, or a shorter win than the best one so far.\n",
    "            if self.best_trajectory is None or fitness < self.best_fitness:\n",
    "                self.best_trajectory = trajectory\n",
    "                self.best_fitness = fitness\n",
    "            return\n",
    "\n",
    "        # First visit to a new, non-winning cell.\n",
    "        self.cells.append(cell)\n",
    "        self.fitnesses[cell] = fitness\n",
    "        self.trajectories[cell] = trajectory\n",
    "        self.states[cell] = state\n",
    "        self.cell_prob[cell] = self.NEW_CELL_WEIGHT\n",
    "\n",
    "    def pullCache(self):\n",
    "        \"\"\"Return (policy, cells, fitnesses) for workers to cache locally.\"\"\"\n",
    "        return (self.policy, self.cells, self.fitnesses)\n",
    "\n",
    "    def pullCell(self, cell):\n",
    "        \"\"\"Return (state, trajectory) stored for `cell`; KeyError if unknown.\"\"\"\n",
    "        return (self.states[cell], self.trajectories[cell])\n",
    "\n",
    "    def pullBestTrajectory(self):\n",
    "        \"\"\"Return the shortest winning trajectory, or None if there is none yet.\"\"\"\n",
    "        return self.best_trajectory\n",
    "\n",
    "    def renormalizeCellProbs(self):\n",
    "        \"\"\"Rescale all cell weights, with padding, so that they sum to 1.\n",
    "\n",
    "        Cells added afterwards start at NEW_CELL_WEIGHT again, so each call\n",
    "        shrinks the weight of older cells relative to newer ones.\n",
    "        \"\"\"\n",
    "        padd = self.PROB_PADDING\n",
    "        weights = np.array([self.cell_prob[c] for c in self.cells])\n",
    "        probsSum = weights.sum() + len(self.cells) * padd\n",
    "        for cell in self.cells:\n",
    "            self.cell_prob[cell] = (self.cell_prob[cell] + padd) / probsSum\n",
    "\n",
    "    def updatePolicy(self):\n",
    "        \"\"\"Rebuild policy['weights'] from a reference trajectory.\n",
    "\n",
    "        Uses the best winning trajectory when one exists, otherwise the\n",
    "        trajectory of the most recently discovered cell.\n",
    "        \"\"\"\n",
    "        if self.best_trajectory is None:\n",
    "            reference = self.trajectories[self.cells[-1]]\n",
    "        else:\n",
    "            reference = self.best_trajectory\n",
    "        # The winning branch used to omit the second argument. It is now passed\n",
    "        # in both branches so the table always matches the 12 states sampled by\n",
    "        # the worker -- assumes createMarkov's second argument is the action\n",
    "        # count; TODO confirm against markov.py.\n",
    "        self.policy['weights'] = createMarkov(reference, self.NUM_ACTIONS)\n",
    "\n",
    "    def pullGo(self):\n",
    "        \"\"\"Sample a cell in proportion to its weight; return (state, trajectory).\"\"\"\n",
    "        weights = np.array([self.cell_prob[c] for c in self.cells])\n",
    "        # Weights are only approximately normalized (new cells get raw values),\n",
    "        # so normalize again right before sampling.\n",
    "        normalized_cell_prob = weights / weights.sum()\n",
    "\n",
    "        goCell = self.cells[np.random.choice(np.arange(len(self.cells)), p=normalized_cell_prob)]\n",
    "\n",
    "        return (self.states[goCell], self.trajectories[goCell])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "@ray.remote\n",
    "def GoExploreWorker(game, master):\n",
    "    \"\"\"Endless exploration task: restore a cell, act, report new or better cells.\n",
    "\n",
    "    Args:\n",
    "        game: Game name passed to retro.make.\n",
    "        master: MasterActor handle used to pull work and push results.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: If the pulled policy has a 'type' other than 'random' or\n",
    "            'markov'.\n",
    "\n",
    "    The function loops forever and only leaves through an exception or when\n",
    "    the worker is terminated; the emulator is closed on the way out.\n",
    "    \"\"\"\n",
    "    n_actions = 12          # upper bound used when drawing action indices\n",
    "    episodes_per_sync = 10  # episodes played between refreshes of the cache\n",
    "\n",
    "    env = retro.make(game)\n",
    "    try:\n",
    "        env.reset()\n",
    "        while True:\n",
    "            # `fitnesses` doubles as the set of known cells (O(1) lookup), so\n",
    "            # the cell list returned alongside it is not needed here.\n",
    "            policy, _cells, fitnesses = ray.get(master.pullCache.remote())\n",
    "            policy_type = policy['type']\n",
    "            if policy_type not in ('random', 'markov'):\n",
    "                # Previously this fell through with action=None and failed at\n",
    "                # the action_set lookup; fail early with a clear message.\n",
    "                raise ValueError('Unknown policy type: ' + repr(policy_type))\n",
    "\n",
    "            for _ in range(episodes_per_sync):\n",
    "                state, trajectory = ray.get(master.pullGo.remote())\n",
    "\n",
    "                recurrent_state = None\n",
    "                if policy_type == 'markov':\n",
    "                    recurrent_state = np.random.randint(n_actions)\n",
    "\n",
    "                env.em.set_state(state)\n",
    "                step = 0\n",
    "                while True:\n",
    "                    if policy_type == 'random':\n",
    "                        action = np.random.randint(n_actions)\n",
    "                    else:\n",
    "                        action = sampleMarkov(recurrent_state, policy['weights'])\n",
    "                        recurrent_state = action\n",
    "\n",
    "                    observation, reward, done, info = env.step(action_set[action])\n",
    "                    trajectory.append(action)\n",
    "                    cell = frameToCell(observation, info)\n",
    "                    fitness = len(trajectory)\n",
    "\n",
    "                    if cell not in fitnesses or fitness < fitnesses[cell]:\n",
    "                        # Snapshot the emulator only when a result is actually sent.\n",
    "                        master.pushResult.remote(cell, trajectory.copy(), env.em.get_state(), info, step)\n",
    "                        # Remember what was just reported so the same or a worse\n",
    "                        # result is not pushed again before the next refresh.\n",
    "                        fitnesses[cell] = fitness\n",
    "\n",
    "                    if stopCondition(cell, info, step):\n",
    "                        break\n",
    "                    step += 1\n",
    "    finally:\n",
    "        # Release the emulator so a reused worker process can create a new one.\n",
    "        env.close()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Importing SonicTheHedgehog-Genesis\n",
      "Imported 1 games\n"
     ]
    }
   ],
   "source": [
    "# Import any ROMs found under roms/ (the recorded output shows\n",
    "# SonicTheHedgehog-Genesis being imported).\n",
    "install_games_from_rom_dir('roms/')\n",
    "\n",
    "game = 'SonicTheHedgehog-Genesis'\n",
    "stateStr = 'GreenHillZone.Act1.state'\n",
    "\n",
    "# Start from a randomly generated Markov policy. The meaning of randMarkov's\n",
    "# arguments is defined in markov.py; 12 presumably is the action count -- TODO confirm.\n",
    "initialPolicy = {'type':'markov', 'weights':randMarkov(10,12)}\n",
    "\n",
    "# Seed values for the archive, obtained for the chosen game and start state.\n",
    "initialCell, initialState, initialTrajectory, initialFitness = getInitial(game, stateStr)\n",
    "\n",
    "# Number of GoExploreWorker tasks launched below.\n",
    "NWorkers = 8\n",
    "\n",
    "# One actor owns the shared archive; each worker receives its handle and then\n",
    "# loops indefinitely (GoExploreWorker never returns).\n",
    "master = MasterActor.remote(initialPolicy, initialCell, initialFitness, initialTrajectory, initialState)\n",
    "workers = [ GoExploreWorker.remote(game, master) for _ in range(NWorkers)]    "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "#time.sleep(10)\n",
    "#policy, cells, fitnesses = ray.get(master.pullCache.remote())\n",
    "#test_cell = cells[-1]\n",
    "#state, trajectory = ray.get(master.pullCell.remote(test_cell))\n",
    "#verifyTrajectory(game, stateStr, trajectory, state)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Time elapsed: 3.003373384475708, Cells: 1\n",
      "Time elapsed: 13.310152769088745, Cells: 148\n"
     ]
    }
   ],
   "source": [
    "# Driver loop. Once per second it asks the master to renormalize its sampling\n",
    "# weights and rebuild the policy. Every 10th iteration it prints a status line,\n",
    "# and every 200th iteration it also renders a GIF. Runs until interrupted.\n",
    "start_time = time.time()\n",
    "i = 0\n",
    "while True:\n",
    "    time.sleep(1)\n",
    "    master.renormalizeCellProbs.remote()\n",
    "    master.updatePolicy.remote()\n",
    "\n",
    "    report_due = (i % 10 == 0)\n",
    "    gif_due = (i % 200 == 0)\n",
    "\n",
    "    if report_due:\n",
    "        policy, cells, fitnesses = ray.get(master.pullCache.remote())\n",
    "        best_trajectory = ray.get(master.pullBestTrajectory.remote())\n",
    "        elapsed = time.time() - start_time\n",
    "        message = ''.join(['Time elapsed: ', str(elapsed), ', Cells: ', str(len(cells))])\n",
    "\n",
    "        if best_trajectory is not None:\n",
    "            best_length = str(len(best_trajectory))\n",
    "            message = message + ', Best trajectory length: ' + best_length\n",
    "            if gif_due:\n",
    "                gif_name = 'Gameplay_FIN_' + best_length + '.gif'\n",
    "                trajectoryToGif(game, stateStr, best_trajectory, True, gif_name)\n",
    "        elif gif_due:\n",
    "            # No win yet: render the cell whose first parsed component sorts last.\n",
    "            first_components = np.array([literal_eval(known)[0] for known in cells])\n",
    "            cell = cells[first_components.argsort()[-1]]\n",
    "            state, trajectory = ray.get(master.pullCell.remote(cell))\n",
    "            c = literal_eval(cell)\n",
    "            suffix = '-'.join(str(c[k]) for k in range(4))\n",
    "            gif_name = 'Gameplay_FR_' + str(len(trajectory)) + '-' + suffix + '.gif'\n",
    "            trajectoryToGif(game, stateStr, trajectory, True, gif_name)\n",
    "\n",
    "        print(message)\n",
    "\n",
    "    i += 1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
